# 此脚本用来测试行人搜索
# 第一步先通过detect生成目标检测结果
# 第二步通过和真实标签计算IOU为每个测试结果分配人物id
# 第三步根据测试结果和人物id保存图片
# 第四步 通过reid代码测试 reid 准确率
 
# 测试集和标签都是一个图片一个标签文件
 
from tqdm import tqdm
from PIL import Image
from PIL import ImageDraw
import os
from scipy.io import loadmat
# ---------------------------------------------------------------------------
# Input locations
# ---------------------------------------------------------------------------
# Scene images of the split being evaluated ("<frame>.jpg").
image_folder = "/home/jia/fsdownload/PRW_myself/images/val"
# Detection results: one "<frame>.txt" file per image.
test_txt_folder = "/home/jia/fsdownload/yolov3-master/runs/detect/exp/prw-yolov3/labels"
# Ground-truth annotations: one "<frame>.jpg.mat" file per image.
gt_mat_folder = "/home/fby/datasets/PRW-v16.04.20/annotations"
# .mat file listing the test frames (a file path, despite the "_dir" suffix).
det_frame_dir = "/home/fby/datasets/PRW-v16.04.20/frame_test.mat"
# Placeholder image saved for ground-truth persons without a matching detection
# (also a file path, despite the "_dir" suffix).
empty_image_dir = "/home/fby/black.jpg"

# ---------------------------------------------------------------------------
# Output location
# ---------------------------------------------------------------------------
reid_output_folder = "/home/fby/datasets/reid_prw_test_yolo"
# Create the gallery folder up front so the per-person image saves never fail
# on a missing directory.
os.makedirs(os.path.join(reid_output_folder, "market1501/bounding_box_test"), exist_ok=True)


def compute_iou(a, b):
    """Return the intersection-over-union (IoU) of two axis-aligned boxes.

    Args:
        a: first box in corner form ``[x1, y1, x2, y2]`` (``x1 <= x2``,
            ``y1 <= y2``).
        b: second box, same format as ``a``.

    Returns:
        ``intersection_area / union_area`` as a float. Returns ``0.0`` when
        the union area is not positive (e.g. both boxes are degenerate) instead
        of raising ``ZeroDivisionError``.
    """
    x1 = max(a[0], b[0])
    y1 = max(a[1], b[1])
    x2 = min(a[2], b[2])
    y2 = min(a[3], b[3])
    # Clamp at 0 so non-overlapping boxes contribute no intersection.
    inter = max(0, x2 - x1) * max(0, y2 - y1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    if union <= 0:
        # Zero-area boxes (e.g. a collapsed detection) cannot overlap anything.
        return 0.0
    return inter / union
 
# Pass 1: build the re-id gallery from the detection results.
# For every image that has a detection file, each ground-truth person with a
# non-negative pid is matched to the FIRST detection (in file order) whose IoU
# with it reaches the size-adaptive threshold. That detection's crop is saved
# as "<pid>_<frame>.jpg"; if no detection matches, the placeholder image is
# saved under that name instead.
test_txts_fns = os.listdir(test_txt_folder)
for test_txt in tqdm(test_txts_fns):
    # Remove only the ".txt" suffix. str.strip(".txt") would strip any of the
    # characters '.', 't', 'x' from BOTH ends and could mangle the frame name.
    frame_stem = os.path.splitext(test_txt)[0]

    fgt_mat = loadmat(os.path.join(gt_mat_folder, frame_stem + ".jpg.mat"))
    # NOTE(review): assumes the annotation array is the last key returned by
    # loadmat (after its "__header__"/"__version__"/"__globals__" entries).
    fgt_mat = fgt_mat[list(fgt_mat.keys())[-1]]

    with Image.open(os.path.join(image_folder, frame_stem + ".jpg")) as scence_img:
        width = scence_img.width
        height = scence_img.height

        # Parse the detection file once per image (it used to be re-read for
        # every ground-truth person). Each row is "class cx cy w h ..." with
        # centre/size normalized by the image size; convert each row to
        # absolute corner form [x1, y1, x2, y2].
        test_boxes = []
        with open(os.path.join(test_txt_folder, test_txt), "r", encoding="utf-8") as ftest_txt:
            for test_line in ftest_txt:
                fields = test_line.split()
                if not fields:
                    continue  # tolerate blank lines
                cx = float(fields[1]) * width
                cy = float(fields[2]) * height
                w = float(fields[3]) * width
                h = float(fields[4]) * height
                test_boxes.append([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h])

        for gt_line in fgt_mat:
            gt_pid = int(gt_line[0])
            if gt_pid < 0:
                # Negative ids are skipped (presumably unlabelled persons — confirm).
                continue
            # Annotation row is [pid, x, y, w, h]; keep w/h for the threshold.
            gt_x, gt_y, gt_w, gt_h = (float(v) for v in gt_line[1:5])
            gt_box = [gt_x, gt_y, gt_x + gt_w, gt_y + gt_h]
            # Size-adaptive threshold min(0.5, w*h / ((w+10)*(h+10))): small
            # persons get a more lenient IoU requirement. It must be computed
            # from the box width/height, not from the x2/y2 corner coordinates.
            iou_thresh = min(0.5, (gt_w * gt_h) / ((gt_w + 10) * (gt_h + 10)))

            out_path = os.path.join(
                reid_output_folder,
                "market1501/bounding_box_test/" + str(gt_pid) + "_" + frame_stem + ".jpg",
            )
            # Only the first detection that clears the threshold is used.
            matched = next(
                (box for box in test_boxes if compute_iou(box, gt_box) >= iou_thresh),
                None,
            )
            if matched is not None:
                # crop() returns a new image, so no copy of the scene is needed.
                scence_img.crop(tuple(int(c) for c in matched)).save(out_path)
            else:
                with Image.open(empty_image_dir) as empty_img:
                    empty_img.save(out_path)
 
# Pass 2: test frames that have no detection file at all.
# Every ground-truth person with a non-negative pid in such a frame gets the
# placeholder image saved as "<pid>_<frame>.jpg".
detected_txts = set(test_txts_fns)  # O(1) membership tests in the loop below
test_frames = loadmat(det_frame_dir)
# NOTE(review): assumes the frame-name array is the last key returned by loadmat.
test_frames = test_frames[list(test_frames.keys())[-1]]
for frame_name in test_frames:
    # The .mat entry is the frame name without extension (".txt" is appended),
    # so use it directly as the stem instead of stripping characters back off.
    frame_stem = str(frame_name[0][0])
    test_det_txt_fn = frame_stem + ".txt"
    if test_det_txt_fn in detected_txts:
        continue
    print("no det in " + test_det_txt_fn + "\n")

    fgt_mat = loadmat(os.path.join(gt_mat_folder, frame_stem + ".jpg.mat"))
    fgt_mat = fgt_mat[list(fgt_mat.keys())[-1]]

    for gt_line in fgt_mat:
        gt_pid = int(gt_line[0])
        if gt_pid < 0:
            continue
        out_path = os.path.join(
            reid_output_folder,
            "market1501/bounding_box_test/" + str(gt_pid) + "_" + frame_stem + ".jpg",
        )
        with Image.open(empty_image_dir) as empty_img:
            empty_img.save(out_path)
